'''
Created on 2012-7-25

@author: Xindong
'''

import heapq
import itertools


class PriorityQueue:
    '''
    Priority queue backed by a binary heap (heapq) with lazy deletion.

    Each value is queued at most once: adding a value that is already
    queued replaces its previous entry. Entries with a smaller priority are
    popped first; entries with equal priority are popped in insertion order.
    Values must be hashable because they are used as dictionary keys.
    '''
    def __init__(self):
        # Heap of [priority, count, val] entries (lists, so that an entry
        # can be marked as removed in place).
        self.heap = []
        # val -> its live entry inside self.heap
        self.value_map = {}
        # Placeholder written over the value of a cancelled entry
        self.REMOVED = '<removed-task>'
        # Monotonic sequence number: breaks priority ties in insertion order
        # and guarantees the values themselves are never compared.
        self.counter = itertools.count()

    def add(self, val, priority=0):
        '''
        Queue val with the given priority, replacing any existing entry
        for the same value.

        Complexity:
            value mapping (dict): O(1) on average
            push to binary heap: O(logN)
        '''
        if val in self.value_map:
            # Cancel the stale entry; it stays in the heap until popped.
            self.remove(val)
        count = next(self.counter)
        entry = [priority, count, val]
        self.value_map[val] = entry
        heapq.heappush(self.heap, entry)

    def remove(self, val):
        '''
        Cancel a queued value.

        The entry is only marked as removed here, which is O(1) on average;
        it is physically discarded when pop() later reaches it.

        Raises:
            KeyError: if val is not currently queued.
        '''
        if val not in self.value_map:
            raise KeyError(str(val) + " not in the heap")
        entry = self.value_map.pop(val)
        entry[-1] = self.REMOVED

    def pop(self):
        '''
        Remove and return the queued value with the smallest priority,
        or None when no live value is left.

        Complexity:
            value mapping (dict): O(1) on average
            pop from binary heap: O(logN) per entry popped, including
            previously cancelled entries that are skipped on the way
        '''
        while self.heap:
            _priority, _count, val = heapq.heappop(self.heap)
            # Identity check: only entries overwritten by remove() carry
            # this exact placeholder object.
            if val is not self.REMOVED:
                del self.value_map[val]
                return val
        return None


class DisjointSet:
    '''
    Disjoint-set (union-find) forest with path compression and union by size.

    With both heuristics the amortized cost per operation is nearly
    constant (inverse Ackermann function of the number of elements).
    '''
    def __init__(self, elements):
        '''
        Put every element in its own singleton set.

        elements: iterable of hashable values. It is traversed only once,
        so one-shot iterators such as generators are accepted.
        '''
        # parent[v] is v's parent in the forest; a root is its own parent.
        self.parent = {val: val for val in elements}
        # card[r] is the size of the set rooted at r (only meaningful for
        # roots). Built from parent's keys so elements is not consumed twice.
        self.card = dict.fromkeys(self.parent, 1)

    def find_root(self, x):
        '''
        Return the representative (root) of the set containing x.

        Raises KeyError if x was not among the initial elements.
        '''
        # First pass: walk up to the root.
        r = x
        p = self.parent[r]
        while r != p:
            r = p
            p = self.parent[r]
        # Second pass (path compression): point every node on the walked
        # path directly at the root.
        while x != r:
            y = self.parent[x]
            self.parent[x] = r
            x = y
        return r

    def union(self, x, y):
        '''
        Merge the sets containing x and y.

        The root of the smaller set is attached under the root of the
        larger one; on a tie the root of y's set becomes the new root.
        Does nothing when x and y are already in the same set.
        '''
        r_x = self.find_root(x)
        r_y = self.find_root(y)
        if r_x == r_y:
            # Already merged: adding the size to itself would corrupt card.
            return
        if self.card[r_x] > self.card[r_y]:
            self.parent[r_y] = r_x
            self.card[r_x] += self.card[r_y]
        else:
            self.parent[r_x] = r_y
            self.card[r_y] += self.card[r_x]

    def to_sets(self):
        '''
        Return the current partition as a dict mapping each root to the
        set of all elements it represents.
        '''
        sets = dict()
        for val in self.parent:
            # find_root only rewrites existing values of self.parent, so
            # iterating over the dict here is safe.
            sets.setdefault(self.find_root(val), set()).add(val)
        return sets


if __name__ == "__main__":
    # --- DisjointSet demo: merge in three stages, printing the partition
    # after each stage.
    ds = DisjointSet([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for stage in ([(2, 5), (3, 4)], [(8, 9), (5, 9)], [(3, 8)]):
        for left, right in stage:
            ds.union(left, right)
        print(ds.to_sets())

    # --- PriorityQueue demo: the same five tasks are queued three times.
    # Every round pops six times, i.e. at least once more than there are
    # live tasks, to show what pop() yields on an empty queue.
    q = PriorityQueue()
    tasks = (("B", 2), ("E", 5), ("C", 3), ("D", 4), ("A", 1))

    # Round 1: plain add / pop.
    for name, prio in tasks:
        q.add(name, prio)
    for _ in range(6):
        print(q.pop())

    # Round 2: cancel two tasks before popping.
    for name, prio in tasks:
        q.add(name, prio)
    for name in ("C", "A"):
        q.remove(name)
    for _ in range(6):
        print(q.pop())

    # Round 3: re-add an already queued task with a new priority.
    for name, prio in tasks:
        q.add(name, prio)
    q.add("C", 0)
    for _ in range(6):
        print(q.pop())

